import colorsys
import math
import os
import pandas as pd
from pathlib import Path

import cv2
import numpy as np
from matplotlib import colors
from matplotlib import pyplot as plt
from read_roi import read_roi_zip, read_roi_file
from scipy.ndimage import find_objects, binary_fill_holes, binary_erosion
from skimage import io
from skimage.filters import median, threshold_multiotsu, threshold_li
from skimage.morphology import square, closing, opening, remove_small_holes
from itertools import combinations

import img_util.threshold


def rgb_to_hsv(arr):
    """
    Convert an RGB array to HSV, element by element, using ``colorsys.rgb_to_hsv``.

    Parameters
    ----------
    arr: ndarray
        Array whose last axis holds exactly 3 channels (r, g, b).

    Returns
    -------
    hsv: ndarray
        Array of the same shape whose last axis holds (h, s, v).
    """
    red, green, blue = np.moveaxis(arr, -1, 0)
    hue, sat, val = np.vectorize(colorsys.rgb_to_hsv)(red, green, blue)
    return np.stack((hue, sat, val), axis=-1)


def hsv_to_rgb(arr):
    """
    Convert an HSV array to RGB, element by element, using ``colorsys.hsv_to_rgb``.

    Parameters
    ----------
    arr: ndarray
        Array whose last axis holds exactly 3 channels (h, s, v).

    Returns
    -------
    rgb: ndarray
        Array of the same shape whose last axis holds (r, g, b).
    """
    hue, sat, val = np.moveaxis(arr, -1, 0)
    red, green, blue = np.vectorize(colorsys.hsv_to_rgb)(hue, sat, val)
    return np.stack((red, green, blue), axis=-1)


def get_roi(roi_path):
    """
    Read the ROIs stored in an ImageJ ROI file.

    Parameters
    ----------
    roi_path: string
        Path to an ImageJ ROI file: ".roi" for a single ROI, ".zip" for several ROIs.

    Returns
    -------
    roi_value: list
        ROI entries as returned by read_roi, in the same order as roi_key
    roi_key: list
        ROI names

    Raises
    ------
    ValueError
        If the file extension is neither ".roi" nor ".zip".
    """
    suffix = Path(roi_path).suffix
    if suffix == ".roi":
        rois = read_roi_file(roi_path)  # single ROI
    elif suffix == ".zip":
        rois = read_roi_zip(roi_path)  # multiple ROIs
    else:
        raise ValueError("Error: path of ROI must be in .roi or .zip format")

    roi_key = list(rois.keys())
    roi_value = list(rois.values())
    return roi_value, roi_key


def masks_to_outlines(masks):
    """
    Get the outlines of each mask (algorithm from cellpose)

    Parameters
    ----------
    masks: ndarray
        2D array of masks

    Returns
    -------
    outlines:
        2D array of outlines of masks

    """

    masks = masks.copy()

    if masks.ndim != 2:
        raise ValueError("Number of dimensions for masks must be 2, not %d" % masks.ndim)

    outlines = np.zeros(masks.shape, np.uint16 if masks.max() < 2 ** 16 - 1 else np.uint32)
    slices = find_objects(masks.astype(int))

    for i, si in enumerate(slices):

        # The findContours method may be overkill. Alternatively can subtract out eroded mask to get outline
        if si is not None:
            sr, sc = si
            mask_i = (masks[sr, sc] == (i + 1)).astype(np.uint8)
            contours = cv2.findContours(mask_i, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            pc, pr = np.concatenate(contours[-2], axis=0).squeeze().T  # outline pixel col and row
            out_r, out_c = pr + sr.start, pc + sc.start
            outlines[out_r, out_c] = i + 1

    return outlines


def rois_to_masks(roi_path, mask_arr_shape, fill_roi=True):
    """
    Turn the ROI from an ImageJ .roi file into a boolean mask array.

    Parameters
    ----------
    roi_path: str
        Path to ROI file (must be .roi extension)
    mask_arr_shape: tuple
        2D shape of mask array
    fill_roi: bool
        True=ROIs are filled. False=ROIs are outlined.

    Returns
    -------
    roi_mask: ndarray
        2D array of boolean of ROI masks

    Raises
    ------
    ValueError
        If the file is not a .roi file, or the ROI type is not 'rectangle' or 'composite'.
    """

    if Path(roi_path).suffix != '.roi':
        raise ValueError("ROI needs to be a .roi file")

    roi_mask = np.zeros(mask_arr_shape, np.uint8)
    roi_all, _ = get_roi(roi_path)
    roi = roi_all[0]  # a .roi file yields a single ROI dictionary

    if roi['type'] == 'composite':

        # Draw each sub-path as a closed polyline: every point joins the next, the last joins the first.
        # cv2.line draws in place on roi_mask.
        for path in roi['paths']:
            pts = [[math.floor(v) for v in pt] for pt in path]  # integer (col, row) corner points
            for (c1, r1), (c2, r2) in zip(pts, pts[1:] + pts[:1]):
                cv2.line(roi_mask, (c1, r1), (c2, r2), 1)

    elif roi['type'] == 'rectangle':

        l, t, w, h = roi['left'], roi['top'], roi['width'], roi['height']
        # An ImageJ rectangle covers columns l..l+w-1 and rows t..t+h-1. cv2.rectangle's corners
        # are inclusive, so the far corner is (l + w - 1, t + h - 1).
        cv2.rectangle(roi_mask, (l, t), (l + w - 1, t + h - 1), 1)

    else:
        raise ValueError("ROI must be of type 'rectangle' or 'composite', not '%s'"
                         % roi['type'])

    roi_mask = roi_mask > 0  # turn masks into boolean
    if fill_roi:
        roi_mask = binary_fill_holes(roi_mask)

    return roi_mask


def add_bg_mask_to_img(img, mask, mask_colour, mask_fill_alpha):
    """
    Add a coloured outline and fill of a binary mask (e.g., background ROI) to an image.

    Parameters
    ----------
    img: ndarray
        a 2D grayscale array, or a 3D RGB array
    mask: ndarray of bool
        a 2D boolean array, same 2D shape as img (non-boolean arrays are converted, non-zero = mask)
    mask_colour: str or None
        a string of named colour for matplotlib's "colors" module. None = draw no mask, only
        convert the image to 8-bit RGB.
    mask_fill_alpha: int or float
        a scalar for scaling mask fill transparency (0=transparent, 1=solid colour)

    Returns
    -------
    img_rgb: ndarray of uint8
        a 3D RGB array of image with background mask added, clipped to [0, 255]

    Raises
    ------
    ValueError
        If img is not 2D or 3D, or if the first two dimensions of img and mask differ.
    """

    if img.ndim == 3:
        img_rgb = img.astype(np.uint32)  # copy in a wide dtype so adding the fill cannot overflow
    elif img.ndim == 2:
        img_rgb = img[..., np.newaxis] * np.ones(3, dtype=np.uint32)  # broadcast grayscale into RGB
    else:
        raise ValueError("Image must be 2D array (grayscale) or 3D array (RGB)")

    # Add background colour to image only if 'mask_colour' is specified
    if mask_colour is not None:
        mask = np.asarray(mask, dtype=bool)  # integer masks would otherwise be used as fancy indices
        if img.shape[:2] != mask.shape[:2]:
            raise ValueError("The first 2D shape of img and mask must match")

        mask_outline = mask ^ binary_erosion(mask, border_value=1)  # bitwise XOR to subtract eroded mask to get outline
        mask_rgb = np.asarray(colors.to_rgb(mask_colour)) * 255  # retrieve mask colour in RGB

        # loop through each RGB channel (views into img_rgb) to add in mask outline and fill
        for img_chl, mask_rgb_chl in zip(np.moveaxis(img_rgb, -1, 0), mask_rgb):
            img_chl[mask] = img_chl[mask] + mask_rgb_chl * mask_fill_alpha  # mask fill
            img_chl[mask_outline] = mask_rgb_chl  # mask outline

    # Clip and cast in every case so the output is always an 8-bit RGB image, mask or not.
    return np.clip(img_rgb, 0, 255).astype(np.uint8)


def threshold_background_mask(bg_img, rm_holes_area, open_close_size=3, med_filter_size=None, rm_small_bg_area=None):
    """
    Create a mask for bright background, using Li's thresholding method. Holes inside the mask are
    filled, edges are smoothed, and small isolated background regions can optionally be removed.

    Parameters
    ----------
    bg_img: ndarray
        2D image array containing background.
    rm_holes_area: int or float
        fill holes smaller than this size (in pixels).
    open_close_size: int
        size of structure element for performing morphological opening then closing (for smoothing edges). Default=3.
    med_filter_size: int or None
        size of structure element of median filter of input image. None = no median filter.
    rm_small_bg_area: int or float or None
        remove isolated background area smaller than this size (in pixels). None = keep all.

    Returns
    -------
    bg_mask: ndarray of bool
        2D boolean mask of thresholded bright background (invert for dark background)
    """

    # median filter to remove noise (returns a new array; the caller's image is untouched)
    if med_filter_size is not None:
        bg_img = median(bg_img, square(med_filter_size))

    thresh_val = threshold_li(bg_img)  # threshold using Li's method
    # The comparison already yields a bool array; the old np.zeros(..., np.bool) relied on an alias
    # that NumPy 1.24 removed.
    bg_mask = bg_img >= thresh_val
    bg_mask = remove_small_holes(bg_mask, area_threshold=rm_holes_area)

    footprint = square(open_close_size)
    bg_mask = closing(opening(bg_mask, footprint), footprint)

    if rm_small_bg_area is not None:
        # Small isolated background regions become holes of the inverted mask, which are then filled.
        bg_mask = ~remove_small_holes(~bg_mask, area_threshold=rm_small_bg_area)

    return bg_mask


def binary_masks_to_colour(masks, colour='red'):
    """
    Turn label masks into an RGB image where each mask is coloured.

    Parameters
    ----------
    masks: ndarray
        2D array of masks (0 = background; bool or integer labels)
    colour: str or None
        Optional: colour of masks. Default='red'. None = random hue for each mask (saturation and
        value 1). Random hues are reproducible (fixed seed 123) and do not touch NumPy's global RNG.

    Returns
    -------
    masks_rgb: ndarray of uint8
        3D image array of masks (2D image in RGB); background pixels are black
    """

    masks_hsv = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
    mask_idx = np.unique(masks)  # retrieve masks' ID
    mask_idx = mask_idx[mask_idx != 0]  # remove 0 (non-mask) from mask index

    # Local generator seeded like the former np.random.seed(123): same hues, but the caller's
    # global random state is left alone.
    rng = np.random.RandomState(123)
    fixed_hsv = None if colour is None else colors.rgb_to_hsv(colors.to_rgb(colour))  # loop-invariant

    for i in mask_idx:
        ipix = (masks == i)  # boolean index of this mask's pixels
        if fixed_hsv is None:
            masks_hsv[ipix, 0] = rng.rand()  # randomize hue
            masks_hsv[ipix, 1:] = 1  # saturation and value set to 1
        else:
            masks_hsv[ipix] = fixed_hsv

    masks_rgb = (hsv_to_rgb(masks_hsv) * 255).astype(np.uint8)

    return masks_rgb


def add_colour_mask_to_img(img, masks):
    """
    Add RGB coloured masks to an image. Image can be grayscale or RGB.

    Every pixel where the mask image is non-zero in any channel is replaced by the mask colour;
    other pixels keep the image value. The input arrays are not modified.

    Parameters
    ----------
    img: ndarray
        2D array of grayscale image, or 3D array of RGB image
    masks: ndarray
        3D array of masks in RGB, same first two dimensions as img

    Returns
    -------
    img_rgb: ndarray
        3D array of image with masks, in RGB

    Raises
    ------
    ValueError
        If img is not 2D/3D or masks is not 3D.
    """

    if img.ndim == 3:
        img_rgb = img.copy()
    elif img.ndim == 2:
        img_rgb = img[..., np.newaxis] * np.ones(3, dtype=np.uint8)  # broadcast grayscale into RGB (new array)
    else:
        raise ValueError("Image must be 2D array (grayscale) or 3D array (RGB)")

    if masks.ndim != 3:
        raise ValueError("Masks must be 3D array (RGB)")

    # 2D boolean index of mask pixels (any RGB channel non-zero). Indexing the 3D arrays with it
    # selects whole RGB triplets, so no broadcast to 3D (or removed np.bool alias) is needed.
    masks_id = np.any(masks != 0, axis=2)
    img_rgb[masks_id] = masks[masks_id]

    return img_rgb


def adaptive_threshold_masks_old(img, masks, diameter, mask_edge_pix, mask_ol_cent_fact,
                                 min_size_factor=1, backgnd_int=5):
    """
    Legacy adaptive thresholding of masks against the original image.

    A mask is kept only if it passes all three checks:
    1) its pixel count is at least floor(min_size_factor * pi * (diameter / 2) ** 2);
    2) its mean intensity is at least the lowest 3-class multi-Otsu threshold, computed on image
       pixels strictly between backgnd_int and 255 - backgnd_int;
    3) its mean intensity in the centre exceeds mask_ol_cent_fact times its mean intensity on an
       outline mask_edge_pix pixels thick (the whole-mask mean stands in for the centre when
       erosion leaves no centre pixels).
    Kept masks are relabelled 1, 2, 3, ... in ascending order of original label; rejected masks
    become 0. The multi-Otsu thresholds are printed.

    Parameters
    ----------
    img: ndarray of int
        2D image array in grayscale
    masks: ndarray of int
        2D array of masks
    diameter: float
        Minimum diameter of cells in pixels
    mask_edge_pix: int
        Scalar: number of pixels to include as masks' outline
    mask_ol_cent_fact: float
        Scalar: grayscale intensity mask's center must be greater than mask's edge by this factor
    min_size_factor: float
        Scalar: factor for scaling minimum cell size based on diameter
    backgnd_int: int
        Scalar: grayscale intensity of background

    Returns
    -------
    masks_o, mask_avg_int, mask_npix:
        2D array of thresholded masks, list of average intensity for each kept mask, list of number
        of pixels for each kept mask
    """
    image = img.copy()
    labels = masks.copy()

    if image.ndim != 2:
        raise ValueError("Image must be a 2D array.")
    if labels.ndim != 2:
        raise ValueError("Masks must be a 2D array.")

    min_area = math.floor(min_size_factor * np.pi * (diameter / 2) ** 2)

    # Estimate the soma threshold from pixels that are neither background nor near-saturated.
    in_range = (image > backgnd_int) & (image < 255 - backgnd_int)
    otsu_levels = threshold_multiotsu(image[in_range], classes=3)
    print("multi-otsu thresholds are: " + str(otsu_levels))
    soma_level = otsu_levels[0]  # lowest threshold separates putative soma from the rest

    kept = labels.copy()
    kept_avg_int = []
    kept_npix = []
    next_label = 1

    label_ids = np.unique(labels)
    for label in label_ids[label_ids != 0]:
        region = labels == label
        area = np.count_nonzero(region)
        avg_int = np.average(image[region])

        # Outline = region minus its erosion; centre = the rest of the region.
        rim = region ^ binary_erosion(region, iterations=mask_edge_pix)
        core = region ^ rim
        rim_int = np.average(image[rim])
        core_int = np.average(image[core]) if core.any() else avg_int

        fails_size_or_intensity = area < min_area or avg_int < soma_level
        if not fails_size_or_intensity and core_int > (rim_int * mask_ol_cent_fact):
            kept[region] = next_label
            next_label += 1
            kept_avg_int.append(avg_int)
            kept_npix.append(area)
        else:
            kept[region] = 0

    return kept, kept_avg_int, kept_npix


def add_npmask_to_img_via_path(mask_path, img_path):
    """
    Load a mask array from a .npy file and an image from disk, and draw the mask outlines in red
    on the image.

    Parameters
    ----------
    mask_path: str
        Path to mask file. Must be numpy (.npy) save file holding a 2D label array.
    img_path: str
        Path to image file.

    Returns
    -------
    img_w_mask: ndarray
        3D RGB array of the image with mask outlines
    """
    # NOTE(review): allow_pickle=True lets np.load run code embedded in a crafted file; only load
    # mask files from trusted sources.
    label_arr = np.load(mask_path, allow_pickle=True)
    image = io.imread(img_path)
    outline_rgb = binary_masks_to_colour(masks_to_outlines(label_arr))
    return add_colour_mask_to_img(image, outline_rgb)


def get_mask_overlap(channel_mask_dict, df_channel, overlap_channel):
    """
    Return thresholded overlapping masks.

    Thresholding rule: for an overlapping pixel, find the corresponding mask in each channel and
    intersect them. Then check whether that intersection, as a fraction of the smallest original
    mask, reaches the smallest mask's channel threshold. This lets a small mask that is completely
    enclosed in a mask in another channel still be accepted.

    Parameters
    ----------
    channel_mask_dict: dict
        Dictionary with, for each channel, a 'thresholded masks' entry whose element 0 is the
        2D label array (0 = background)
    df_channel: dataframe
        Pandas dataframe indexed by channel, with an 'overlap threshold' column
    overlap_channel: list of str
        List of string, of channels to retrieve overlapping masks

    Returns
    -------
    final_overlap: ndarray
        2D array of overlapping mask as boolean, after thresholding
    nmask_ovlap: list of int
        Number of unique masks per channel (same order as overlap_channel) that take part in the
        accepted overlap
    overlap_orig: ndarray
        2D array of original overlapping mask as boolean, before thresholding

    Raises
    ------
    ValueError
        If overlap_channel is empty or the channels' mask arrays differ in shape.
    """
    if len(overlap_channel) == 0:
        raise ValueError("overlap_channel must contain at least one channel.")

    mask_ch = []  # label arrays for each channel (read only)
    channel_ovlap_thresh = []
    for ch in overlap_channel:
        mask_ch.append(channel_mask_dict[ch]['thresholded masks'][0])
        channel_ovlap_thresh.append(df_channel.loc[ch, 'overlap threshold'])

    # Pixels covered by a mask in every channel
    mask1_shape = mask_ch[0].shape
    overlap_orig = (mask_ch[0] != 0)
    for m in mask_ch[1:]:
        if m.shape != mask1_shape:
            raise ValueError("Mask array shapes are not all equal.")
        overlap_orig = overlap_orig & (m != 0)

    overlap_remain = overlap_orig.copy()  # overlap pixels not yet processed
    final_overlap = np.zeros(mask1_shape, dtype=bool)  # initialize final thresholded overlap masks to false

    while overlap_remain.any():

        # Seed on the first remaining overlap pixel (row-major order). Every channel has a
        # non-zero label there, so each proposed mask is a real mask.
        r0, c0 = np.unravel_index(np.argmax(overlap_remain), mask1_shape)
        mask_pro = [m == m[r0, c0] for m in mask_ch]  # full original mask in each channel
        mask_orig_npx = np.array([p.sum() for p in mask_pro])

        # Intersection of the proposed masks across channels, as a fraction of each original mask
        mask_ovlap = mask_pro[0]
        for p in mask_pro[1:]:
            mask_ovlap = mask_ovlap & p
        mask_ovlap_fract = mask_ovlap.sum() / mask_orig_npx
        smallest_mask = np.argmin(mask_orig_npx)

        # Accept if the smallest original mask's overlap fraction reaches its channel threshold
        if mask_ovlap_fract[smallest_mask] >= channel_ovlap_thresh[smallest_mask]:
            final_overlap |= mask_ovlap
        # The seed pixel lies in mask_ovlap, so each iteration makes progress
        overlap_remain &= ~mask_ovlap

    # final_overlap lies inside overlap_orig, where every channel is non-zero, so every label found
    # there is a real mask. (Dropping the first unique value assumed 0 was present and undercounted
    # when the accepted overlap covered the whole image.)
    nmask_ovlap = [len(np.unique(m[final_overlap])) for m in mask_ch]

    return final_overlap, nmask_ovlap, overlap_orig


def normalize_img(img, ptile):
    """
    Rescale an image so its (100 - ptile)th percentile maps to 255.

    Values are divided by that percentile, clipped to [0, 1] and multiplied by 255. The input is
    not modified.

    Parameters
    ----------
    img: ndarray
        Image array.
    ptile: int or float
        Percent of the brightest values that saturate at 255 (e.g. 1 -> the 99th percentile).

    Returns
    -------
    ndarray of float
        Rescaled image in [0, 255]. NOTE: if that percentile is 0, the division produces inf/nan.
    """
    upper = np.percentile(img.copy(), 100 - ptile)
    scaled = img.copy() / upper
    return np.clip(scaled, 0, 1) * 255


def quant_channels_mask_overlap(process_channels, channel_mask_dict, df_channel, channel_orig_img,
                                save_mask_overlap_img=False, img_save_dir=None, sample_i=None):
    """
    Wrapper function: quantify mask overlap for every combination of 2 or more channels.

    For each combination, one figure is made:
    - panel 0 shows accepted (green) vs rejected (magenta) overlap;
    - the other panels show each channel's original image with the accepted overlap outlined in red.

    Parameters
    ----------
    process_channels: list of str
        List of channels to quantify mask overlap.
    channel_mask_dict: dict
        A dictionary containing 'thresholded masks' for each channel.
    df_channel: dataframe
        Pandas dataframe containing 'overlap threshold' parameter for each channel.
    channel_orig_img: dict
        A dictionary containing each channel's original 2D grayscale image.
    save_mask_overlap_img: bool
        If True, saves each figure to img_save_dir and closes it. Default=False.
    img_save_dir: str
        Directory to save mask overlap images to (created, with parents, if missing).
    sample_i: str
        Current sample ID, to be included in output saved image filename.

    Returns
    -------
    df_overlap: dataframe
        One-row dataframe with 'sample_name' plus one column per channel and combination,
        named 'channel:[chA+chB...]', holding the number of overlapping masks.
    """

    df_overlap = pd.DataFrame({'sample_name': [sample_i]})

    # Quantify mask overlaps, for every channel combinations >= 2 channels
    for comb_r in range(2, len(process_channels) + 1):
        for comb in combinations(process_channels, comb_r):
            channel_comb = sorted(comb)
            comb_label = '+'.join(channel_comb)
            mask_ovlap, nmask_ovlap, orig_ovlap = get_mask_overlap(channel_mask_dict, df_channel, channel_comb)
            mask_ovlap_outline = binary_masks_to_colour(masks_to_outlines(mask_ovlap), colour='red')

            n_subfig = len(channel_comb) + 1
            fig, ax = plt.subplots(nrows=1, ncols=n_subfig,
                                   figsize=tuple(np.array([6.4, 4.8]) * n_subfig))

            # Update dataframe with quantification, and add subplot to figure
            for i, (ch, nm_ovlap) in enumerate(zip(channel_comb, nmask_ovlap)):
                col_name = ch + ':[' + comb_label + ']'
                df_overlap[col_name] = nm_ovlap  # update quantification dataframe
                # add_colour_mask_to_img works on a copy, so the original image is untouched
                img_ovlap_ch = add_colour_mask_to_img(channel_orig_img[ch], mask_ovlap_outline)
                ax[i + 1].imshow(img_ovlap_ch)
                ax[i + 1].set_title(col_name)

            # Plot thresholded mask overlap vs. rejected mask overlap
            img_ovlap_thresh = np.zeros((orig_ovlap.shape[0], orig_ovlap.shape[1], 3), dtype='bool')
            img_ovlap_thresh[:, :, 1] = mask_ovlap  # Green: thresholded mask overlap
            orig_ovlap_x = orig_ovlap ^ mask_ovlap  # XOR for rejected mask overlap
            img_ovlap_thresh[:, :, np.array([0, 2])] = orig_ovlap_x[:, :, np.newaxis]  # Magenta: rejected overlap
            img_ovlap_thresh = img_ovlap_thresh.astype('uint8') * 255
            ax[0].imshow(img_ovlap_thresh)
            ax[0].set_title('Green=overlap, Magenta=rejected overlap')

            # NOTE: when not saving, figures are left open (e.g. for interactive display).
            if save_mask_overlap_img:
                save_dir = Path(img_save_dir)
                save_dir.mkdir(parents=True, exist_ok=True)  # parents=True so nested paths work
                fig.tight_layout()
                fig.savefig(save_dir / (sample_i + '_ovlap_' + comb_label + '_mask.png'))
                plt.close(fig)

    return df_overlap


def mask_npy_to_img(npy_dir, save_dir, img_ext='.png'):
    """
    Convert masks in .npy format in a directory into image format. Useful for loading into
    cellpose GUI for further mask labelling.

    Parameters
    ----------
    npy_dir: str
        Directory of where the mask .npy files are
    save_dir: str
        Directory where image files are saved (created, including parents, if missing). Each
        output is named after its .npy file, with img_ext as extension.
    img_ext: str
        Image file extension to be saved (defaults '.png')

    Returns
    -------
    None
    """

    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    for p in sorted(Path(npy_dir).iterdir()):
        if p.suffix != '.npy' or not p.is_file():
            continue
        # NOTE(review): allow_pickle=True lets np.load run code embedded in a crafted file; only
        # convert directories from trusted sources.
        mask = np.load(str(p), allow_pickle=True)
        # Join with a real path separator; plain string concatenation wrote next to save_dir
        # whenever it lacked a trailing '/'.
        save_path = save_dir / (p.stem + img_ext)
        io.imsave(str(save_path), mask, check_contrast=False)
        print(p.stem + img_ext + ' saved')


def threshold_masks(img, masks, diameter, mask_threshold_dict, min_size_factor=0.65):
    """
    Threshold masks based on original image, using a user-supplied per-mask test.

    For every non-zero label k (ascending), ``mask_threshold_dict['threshold_func']`` is called as
    ``threshold_func(img, masks, k, size_thld, mask_threshold_dict)`` and must return
    ``(accept, avg_int, npix)``. Here ``size_thld = floor(min_size_factor * pi * (diameter / 2) ** 2)``.
    Accepted masks are relabelled 1, 2, 3, ... in order; rejected masks become 0.

    Parameters
    ----------
    img: ndarray of int
        2D image array in grayscale
    masks: ndarray of int
        2D array of masks
    diameter: float
        Minimum diameter of cells in pixels
    mask_threshold_dict: dict
        Dictionary containing mask threshold parameters, including 'threshold_func'.
    min_size_factor: float
        Scalar: factor for scaling minimum cell size based on diameter

    Returns
    -------
    masks_o, mask_avg_int, mask_npix:
        2D array of thresholded masks, list of average intensity for each accepted mask, list of
        number of pixels for each accepted mask (values as returned by threshold_func)

    Raises
    ------
    ValueError
        If img or masks is not 2D.
    """
    threshold_func = mask_threshold_dict['threshold_func']
    img_c = img.copy()
    masks_c = masks.copy()

    if img_c.ndim != 2:
        raise ValueError("Image must be a 2D array.")
    if masks_c.ndim != 2:
        raise ValueError("Masks must be a 2D array.")

    size_thld = math.floor(min_size_factor * np.pi * (diameter / 2) ** 2)

    relabelled = masks_c.copy()
    accepted_avg_int = []
    accepted_npix = []
    new_id = 0

    label_ids = np.unique(masks_c)
    for k in label_ids[label_ids != 0]:
        region = (masks_c == k)
        accept, avg_int, npix = threshold_func(img_c, masks_c, k, size_thld, mask_threshold_dict)

        if not accept:
            relabelled[region] = 0
            continue

        new_id += 1
        relabelled[region] = new_id
        accepted_avg_int.append(avg_int)
        accepted_npix.append(npix)

    return relabelled, accepted_avg_int, accepted_npix


def calc_iou(mask1, mask2):
    """
    Calculate intersection over union of 2 masks of the same shape.

    Parameters
    ----------
    mask1: ndarray
        Array of boolean (any other dtype is converted; non-zero = inside the mask)
    mask2: ndarray
        Array of boolean, same shape as mask1

    Returns
    -------
    iou: float
        Intersection over union; nan when both masks are empty (IoU undefined)

    Raises
    ------
    ValueError
        If the masks do not have the same shape.
    """
    # Convert to bool first: '&' on integer label arrays is bitwise (e.g. 1 & 2 == 0), which
    # silently under-counts the intersection.
    mask1 = np.asarray(mask1, dtype=bool)
    mask2 = np.asarray(mask2, dtype=bool)
    if mask1.shape != mask2.shape:
        # Reject instead of letting NumPy broadcast mismatched shapes into a meaningless result.
        raise ValueError("mask1 and mask2 must have the same shape, got %s and %s"
                         % (mask1.shape, mask2.shape))

    intersect_mask_area = np.count_nonzero(mask1 & mask2)
    union_mask_area = np.count_nonzero(mask1 | mask2)
    if union_mask_area == 0:
        return float('nan')  # both masks empty: same nan as before, without a divide warning

    return intersect_mask_area / union_mask_area


def get_iou_for_mask_id(gndtru_masks, ml_masks, gndtru_mask_index):
    """
    Given a 2D array of ground truth masks, a 2D array of machine-learned masks, and a ground truth
    mask index, find the intersection over union (IoU) with the closest matching ML mask.

    The closest match is the non-background ML mask that covers the most pixels of the ground
    truth mask. ML background (label 0) is never chosen as the match.

    Parameters
    ----------
    gndtru_masks: ndarray
        2D array of mask indices from ground truth.
    ml_masks: ndarray
        2D array of non-negative integer mask indices from machine learning output, same shape.
    gndtru_mask_index: int
        Index of ground truth mask whose IoU will be calculated.

    Returns
    -------
    iou: float
        IoU between the ground truth mask and its matching ML mask (0 for a non-empty ground
        truth mask that no ML mask overlaps)
    ml_mask_match_id: int
        The mask index of closest matching ML mask, or 0 if no ML mask overlaps
    """

    gt_mask_loc = (gndtru_masks == gndtru_mask_index)

    # Count ground-truth pixels per ML label. minlength=1 keeps index 0 valid even when the
    # ground truth mask is empty.
    ml_mask_count = np.bincount(ml_masks[gt_mask_loc], minlength=1)
    ml_mask_count[0] = 0  # force non-mask bincount to 0 so background cannot win
    # argmax over the zeroed counts (previously a fresh bincount was used, so background usually
    # won); 0 means no ML mask overlaps
    ml_mask_match_id = np.argmax(ml_mask_count)

    if ml_mask_match_id != 0:
        ml_mask_loc = (ml_masks == ml_mask_match_id)
    else:
        ml_mask_loc = np.zeros(ml_masks.shape, dtype='bool')  # if no ML mask match, return all False

    iou = calc_iou(gt_mask_loc, ml_mask_loc)

    return iou, ml_mask_match_id


def validate_ml_mask_batch(gndtru_masks, ml_masks, iou_threshold, gndtru_masks_fname):
    """
    Given a 2D array of ground truth masks and a 2D array of machine-learned masks, calculate
    validation counts using an intersection over union (IoU) threshold.

    Each ground truth mask is matched to its best-overlapping ML mask. It counts as a true positive
    if the IoU reaches iou_threshold, otherwise as a false negative. ML masks never matched to any
    ground truth mask are false positives.

    Parameters
    ----------
    gndtru_masks: ndarray
        2D array of mask indices from ground truth.
    ml_masks: ndarray
        2D array of mask indices from machine learning output.
    iou_threshold: float
        IoU threshold.
    gndtru_masks_fname: str
        Ground truth masks filename, used in output dataframe.

    Returns
    -------
    df_sample: dataframe
        Dataframe of 1 row. Contains number of true positives, false positives and false negatives.
        Also contains IoU for each ground truth mask (one column per ground truth mask index).
    """

    gndtru_masks_ls = np.unique(gndtru_masks)
    gndtru_masks_ls = gndtru_masks_ls[gndtru_masks_ls > 0]
    ml_masks_remain = np.unique(ml_masks)
    ml_masks_remain = ml_masks_remain[ml_masks_remain > 0]  # ML masks not yet matched
    iou_ls = np.zeros(gndtru_masks_ls.shape)

    for i, gt_mask_i in enumerate(gndtru_masks_ls):
        # Call the function in this module directly rather than through the img_util.util
        # package path, which only resolves when the module was imported under that name.
        iou, ml_mask_id = get_iou_for_mask_id(gndtru_masks, ml_masks, gt_mask_i)
        iou_ls[i] = iou
        # A matched ML mask is no longer a false-positive candidate, even if its IoU is low
        ml_masks_remain[ml_masks_remain == ml_mask_id] = 0

    num_fal_pos = np.sum(ml_masks_remain > 0)
    num_tru_pos = np.sum(iou_ls >= iou_threshold)
    num_fal_neg = np.sum(iou_ls < iou_threshold)

    df_sample = pd.DataFrame({'sample_name': [gndtru_masks_fname],
                              'num_true_pos': num_tru_pos,
                              'num_false_pos': num_fal_pos,
                              'num_false_neg': num_fal_neg})
    df_masks = pd.DataFrame(iou_ls[np.newaxis, :], columns=gndtru_masks_ls)
    df_sample = df_sample.join(df_masks)

    return df_sample


def get_validation_score(df_validate, channel_name):
    """
    Calculate precision, recall, average precision metric, and F1-score from summed counts.

    With TP, FP and FN summed over all rows of df_validate:
    - precision = TP / (TP + FP)
    - recall = TP / (TP + FN)
    - avg precision metric = TP / (TP + FP + FN)
    - f1-score = TP / (TP + (FP + FN) / 2)

    Parameters
    ----------
    df_validate: dataframe
        Dataframe with 'num_true_pos', 'num_false_pos' and 'num_false_neg' columns.
    channel_name: str
        Channel name of masks segmented by ML.

    Returns
    -------
    df_validation_score: dataframe
        One-row dataframe with columns 'name', 'precision', 'recall', 'avg precision metric'
        and 'f1-score'.
    """
    tp = np.sum(df_validate['num_true_pos'])
    fp = np.sum(df_validate['num_false_pos'])
    fn = np.sum(df_validate['num_false_neg'])

    scores = {
        'name': [channel_name],
        'precision': tp / (tp + fp),
        'recall': tp / (tp + fn),
        'avg precision metric': tp / (tp + fp + fn),
        'f1-score': tp / (tp + (fp + fn) / 2),
    }
    return pd.DataFrame(scores)


def _line_pixels_between(x_start, y_start, x_end, y_end):
    """Return integer (xs, ys) pixel lists strictly between two points, along a straight line."""
    n_steps = max(abs(x_end - x_start), abs(y_end - y_start))
    if n_steps < 2:  # identical or adjacent points: nothing to fill in
        return [], []
    frac = np.arange(1, n_steps) / n_steps
    xs = np.rint(x_start + frac * (x_end - x_start)).astype(int)
    ys = np.rint(y_start + frac * (y_end - y_start)).astype(int)
    return xs.tolist(), ys.tolist()


def rois_to_multi_cell_masks(roi_path):
    """
    Converts an ImageJ ROI .zip file that contains multiple-cell ROIs to a numpy array with
    segmented masks.

    The ImageJ ROI file should contain:
    - a separate ROI for each cell, drawn using the 'freehand' tool;
    - exactly one rectangle ROI enclosing the whole original image, which supplies its width and
      height.

    Parameters
    ----------
    roi_path: str
        File path of ROI .zip.

    Returns
    -------
    roi_masks: ndarray
        2D array of mask indices (freehand ROIs labelled 1, 2, ... in file order; later ROIs
        overwrite earlier ones where they overlap). dtype is uint8 for up to 255 cells and widens
        automatically beyond that.

    Raises
    ------
    ValueError
        If the file is not a .zip, or it does not contain exactly one rectangle ROI.
    """

    file_name = os.path.basename(roi_path)
    print('Processing "' + file_name + '"')

    if os.path.splitext(file_name)[1] != '.zip':
        raise ValueError("Multi-cell ROI needs to be a .zip file")

    # Call the function in this module directly rather than through the img_util.util package path.
    roi_all, _ = get_roi(roi_path)

    roi_type = [roi['type'] for roi in roi_all]

    n_roi_rect = roi_type.count('rectangle')
    if n_roi_rect != 1:
        raise ValueError(
            "Multi-cell ROI file needs to contain 1 rectangle ROI of the image's dimension. Instead it contains " + str(
                n_roi_rect) +
            " rectangle ROI.")

    roi_rect = roi_all[roi_type.index('rectangle')]
    n_rows, n_cols = roi_rect['height'], roi_rect['width']

    # Smallest unsigned dtype holding every label: uint8 up to 255 cells, wider beyond that
    # (a fixed uint8 overflowed past 255 cells).
    roi_masks = np.zeros([n_rows, n_cols], np.min_scalar_type(roi_type.count('freehand')))

    mask_k = 1
    for roi in roi_all:
        if roi['type'] != 'freehand':
            continue

        # Note in ImageJ, x = cols (width), y = rows (height)
        roi_x = [round(j) for j in roi['x']]
        roi_y = [round(j) for j in roi['y']]

        # Join the two ends of the ROI with a straight line so the outline is closed. This works
        # whichever direction the gap runs; the previous range()-based fill only handled gaps where
        # the last point came after the first.
        gap_x, gap_y = _line_pixels_between(roi_x[-1], roi_y[-1], roi_x[0], roi_y[0])
        roi_x += gap_x
        roi_y += gap_y

        # Clip to the image so points rounded onto or past the border cannot raise or wrap around.
        rows = np.clip(roi_y, 0, n_rows - 1)
        cols = np.clip(roi_x, 0, n_cols - 1)

        # retrieve each freehand ROI as mask outline, then fill
        outline = np.zeros([n_rows, n_cols], dtype=bool)
        outline[rows, cols] = True
        roi_masks[binary_fill_holes(outline)] = mask_k
        mask_k += 1

    return roi_masks
